#!/usr/bin/python3

from unittest import TestCase, main, skipUnless

from fathom.errors import FathomError
from fathom.schema import Table, Database, Column

from fathomtools import get_database, get_database_type
from fathomtools.diff import DatabaseDiff, UNCHANGED, CREATED, ALTERED, DROPPED

# Optional database drivers. Each guard sets a TEST_* flag recording whether
# the driver could be imported (consumed by the skipUnless decorators on
# DatabaseTypeTestCase) and, where defined, a *_errors tuple of that driver's
# exception classes (an empty tuple when the driver is missing).
# NOTE(review): the *_errors tuples are not referenced by the tests in this
# module.
try:
    import psycopg2
    postgres_errors = (psycopg2.OperationalError, psycopg2.ProgrammingError,
                       psycopg2.InternalError)
    TEST_POSTGRES = True
except ImportError:
    postgres_errors = ()
    TEST_POSTGRES = False

try:
    import sqlite3
    TEST_SQLITE = True
except ImportError:
    TEST_SQLITE = False

# MySQL: try MySQLdb first, then fall back to pymysql; either one enables
# TEST_MYSQL. mysql_module is only bound when one of the imports succeeds.
TEST_MYSQL = True
try:
    import MySQLdb
    mysql_module = MySQLdb
    mysql_errors = (mysql_module.OperationalError,
                    mysql_module.ProgrammingError)
except ImportError:
    try:
        import pymysql
        mysql_module = pymysql
        mysql_errors = (mysql_module.OperationalError,
                        mysql_module.ProgrammingError,
                        mysql_module.err.InternalError)
    except ImportError:
        mysql_errors = ()
        TEST_MYSQL = False

try:
    import cx_Oracle
    oracle_errors = (cx_Oracle.DatabaseError,)
    TEST_ORACLE = True
except ImportError:
    oracle_errors = ()
    TEST_ORACLE = False

class DatabaseTypeTestCase(TestCase):
    """Tests for ``get_database_type`` backend detection.

    Each positive test is skipped when its driver module could not be
    imported. They presumably also need a reachable local ``fathom`` test
    database per backend -- confirm against the project's test setup.
    """

    @skipUnless(TEST_SQLITE, 'Failed to import sqlite3 module.')
    def test_sqlite(self):
        self.assertEqual(get_database_type('fathom.db3'), 'Sqlite3')

    # TEST_MYSQL is set when either MySQLdb or pymysql imports (see the
    # import guards at the top of the module), so the skip reason names both.
    @skipUnless(TEST_MYSQL, 'Failed to import MySQLdb or pymysql module.')
    def test_mysql(self):
        self.assertEqual(get_database_type(user='fathom', db='fathom'), 'MySQL')

    @skipUnless(TEST_POSTGRES, 'Failed to import psycopg2 module.')
    def test_postgres(self):
        self.assertEqual(get_database_type('dbname=fathom user=fathom'),
                         'PostgreSQL')

    @skipUnless(TEST_ORACLE, 'Failed to import cx_Oracle module.')
    def test_oracle(self):
        self.assertEqual(get_database_type('fathom', 'fathom'), 'Oracle')

    def test_exception(self):
        """Invalid connection arguments make get_database_type raise FathomError.

        Each backend's case runs in its own subTest so that one failing
        backend does not hide the results of the others.
        """
        # (label, positional args, keyword args) passed to get_database_type.
        cases = (
            ('sqlite missing file', ('non_existing_file.db',), {}),
            ('mysql missing database', (),
             {'user': 'fathom', 'db': 'non_existing_database'}),
            ('postgres missing database',
             ('dbname=not_existing_database user=fathom',), {}),
            ('oracle wrong password', ('fathom', 'wrong_password'), {}),
        )
        for label, args, kwargs in cases:
            with self.subTest(label):
                with self.assertRaises(FathomError):
                    get_database_type(*args, **kwargs)


class DatabaseDiffTestCase(TestCase):
    """Tests for ``DatabaseDiff`` table- and column-level state detection."""

    # Readable names for the diff states, used in assertState's failure
    # message.
    STATE_STRINGS = {UNCHANGED: 'UNCHANGED', CREATED: 'CREATED',
                     DROPPED: 'DROPPED', ALTERED: 'ALTERED'}

    def setUp(self):
        # table_1 is shared by base_db and more_tables_db; table_2 exists only
        # in more_tables_db, so diffing the two creates or drops it.
        self.table1 = Table('table_1')
        self.table1.columns = {}
        self.table2 = Table('table_2')
        self.table2.columns = {}

        self.base_db = Database(name='base')
        self.dest_db = Database(name='dest')

        self.base_db.tables = {self.table1.name: self.table1}

        self.more_tables_db = Database(name='more_tables_db')
        self.more_tables_db.tables = {self.table1.name: self.table1,
                                      self.table2.name: self.table2}

    def _state_name(self, state):
        """Return the readable name of *state*, or its repr if unknown."""
        return self.STATE_STRINGS.get(state, repr(state))

    def assertState(self, item, state):
        """Fail unless ``item.state`` equals *state*.

        Unknown states are shown by ``repr`` instead of raising KeyError
        while the failure message is being built.
        """
        if item.state != state:
            raise self.failureException(
                "item state is: %s, expecting %s"
                % (self._state_name(item.state), self._state_name(state)))

    @staticmethod
    def _single_table_diff(table_name, source_columns, dest_columns):
        """Diff two single-table databases and return ``diff.tables``.

        Both databases hold one table named *table_name*. The source table
        gets *source_columns* and the destination table gets *dest_columns*.
        """
        source_table = Table(table_name)
        source_table.columns = source_columns
        dest_table = Table(table_name)
        dest_table.columns = dest_columns

        source_db = Database(name='base')
        source_db.tables = {table_name: source_table}
        dest_db = Database(name='dest')
        dest_db.tables = {table_name: dest_table}

        return DatabaseDiff(source_db, dest_db).tables

    def test_new_table(self):
        """A table present only in the destination is CREATED."""
        diff_tables = DatabaseDiff(self.base_db, self.more_tables_db).tables
        self.assertIn(self.table1.name, diff_tables)
        self.assertIn(self.table2.name, diff_tables)

        self.assertState(diff_tables[self.table1.name], UNCHANGED)
        self.assertState(diff_tables[self.table2.name], CREATED)

    def test_dropped_table(self):
        """A table present only in the source is DROPPED."""
        diff_tables = DatabaseDiff(self.more_tables_db, self.base_db).tables
        self.assertIn(self.table1.name, diff_tables)
        self.assertIn(self.table2.name, diff_tables)

        self.assertState(diff_tables[self.table1.name], UNCHANGED)
        self.assertState(diff_tables[self.table2.name], DROPPED)

    def test_same_tables(self):
        """Diffing a database against itself leaves its tables UNCHANGED."""
        diff_tables = DatabaseDiff(self.base_db, self.base_db).tables
        self.assertIn(self.table1.name, diff_tables)
        self.assertState(diff_tables[self.table1.name], UNCHANGED)

    def test_new_columns(self):
        """A column added in the destination is CREATED; its table ALTERED."""
        col_1 = Column('col_1', 'varchar(10)')
        col_2 = Column('col_2', 'varchar(10)')
        table_name = 'table_1'

        diff_tables = self._single_table_diff(
            table_name, {'col_1': col_1}, {'col_1': col_1, 'col_2': col_2})

        self.assertIn(table_name, diff_tables)
        table = diff_tables[table_name]
        self.assertState(table, ALTERED)
        self.assertIn('col_1', table.columns)
        self.assertState(table.columns['col_1'], UNCHANGED)
        self.assertIn('col_2', table.columns)
        self.assertState(table.columns['col_2'], CREATED)

    def test_remove_columns(self):
        """A column missing from the destination is DROPPED; its table ALTERED."""
        col_1 = Column('col_1', 'varchar(10)')
        col_2 = Column('col_2', 'varchar(10)')
        table_name = 'table_1'

        diff_tables = self._single_table_diff(
            table_name, {'col_1': col_1, 'col_2': col_2}, {'col_1': col_1})

        self.assertIn(table_name, diff_tables)
        table = diff_tables[table_name]
        self.assertState(table, ALTERED)
        self.assertIn('col_1', table.columns)
        self.assertState(table.columns['col_1'], UNCHANGED)
        self.assertIn('col_2', table.columns)
        self.assertState(table.columns['col_2'], DROPPED)

    def test_changed_column1(self):
        """Changing a column's type marks the column and table ALTERED."""
        diff_tables = self._single_table_diff(
            'table',
            {'col_1': Column('col_1', 'varchar(10)')},
            {'col_1': Column('col_1', 'varchar(16)')})

        self.assertIn('table', diff_tables)
        self.assertState(diff_tables['table'], ALTERED)
        self.assertIn('col_1', diff_tables['table'].columns)
        self.assertState(diff_tables['table'].columns['col_1'], ALTERED)

    def test_changed_column2(self):
        """Changing a column's not_null flag marks the column and table ALTERED."""
        diff_tables = self._single_table_diff(
            'table',
            {'col_1': Column('col_1', 'varchar(10)', not_null=True)},
            {'col_1': Column('col_1', 'varchar(10)', not_null=False)})

        self.assertIn('table', diff_tables)
        self.assertState(diff_tables['table'], ALTERED)
        self.assertIn('col_1', diff_tables['table'].columns)
        self.assertState(diff_tables['table'].columns['col_1'], ALTERED)

# find_accessing_procedures tests (module-level functions, not methods of any
# TestCase in this module).

def test_find_accessing_procedures1(self):
    """Expect only get_accessing_procedures_1() to access table one_column."""
    procedures = find_accessing_procedures(self.db.tables['one_column'])
    names = ['get_accessing_procedures_1()']
    self.assertEqual(set(procedures), set(names))
# NOTE(review): this is a module-level function that takes ``self`` and uses
# ``find_accessing_procedures`` and ``self.db``, none of which this module
# defines or imports. unittest ignores it, but pytest/nose collect it by its
# ``test_`` prefix and error out (pytest: "fixture 'self' not found"). Opt it
# out of collection until it is moved into a TestCase that provides ``db``.
test_find_accessing_procedures1.__test__ = False
    
def test_find_accessing_procedures2(self):
    """Expect only get_accessing_procedures_4() to access table SoMe_TaBlE."""
    procedures = find_accessing_procedures(self.db.tables['SoMe_TaBlE'])
    names = ['get_accessing_procedures_4()']
    self.assertEqual(set(procedures), set(names))
# NOTE(review): this is a module-level function that takes ``self`` and uses
# ``find_accessing_procedures`` and ``self.db``, none of which this module
# defines or imports. unittest ignores it, but pytest/nose collect it by its
# ``test_`` prefix and error out (pytest: "fixture 'self' not found"). Opt it
# out of collection until it is moved into a TestCase that provides ``db``.
test_find_accessing_procedures2.__test__ = False
    
def test_find_accessing_procedures3(self):
    """Expect procedures 2 and 3 to access table some_table."""
    procedures = find_accessing_procedures(self.db.tables['some_table'])
    names = ['get_accessing_procedures_2()', 'get_accessing_procedures_3()']
    self.assertEqual(set(procedures), set(names))
# NOTE(review): this is a module-level function that takes ``self`` and uses
# ``find_accessing_procedures`` and ``self.db``, none of which this module
# defines or imports. unittest ignores it, but pytest/nose collect it by its
# ``test_`` prefix and error out (pytest: "fixture 'self' not found"). Opt it
# out of collection until it is moved into a TestCase that provides ``db``.
test_find_accessing_procedures3.__test__ = False

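# find_accessing_procedures test for table one_column, expecting the procedure
# name without a trailing "()" (test_find_accessing_procedures1 expects it with).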
    
def test_find_accessing_procedures(self):
    """Expect only get_accessing_procedures_1 to access table one_column.

    NOTE(review): test_find_accessing_procedures1 checks the same table but
    expects 'get_accessing_procedures_1()' with parentheses. One of the two
    expectations is presumably stale -- confirm against the implementation.
    """
    procedures = find_accessing_procedures(self.db.tables['one_column'])
    names = ['get_accessing_procedures_1']
    self.assertEqual(set(procedures), set(names))
# NOTE(review): this is a module-level function that takes ``self`` and uses
# ``find_accessing_procedures`` and ``self.db``, none of which this module
# defines or imports. unittest ignores it, but pytest/nose collect it by its
# ``test_`` prefix and error out (pytest: "fixture 'self' not found"). Opt it
# out of collection until it is moved into a TestCase that provides ``db``.
test_find_accessing_procedures.__test__ = False

# Allow running this file directly as a script; delegates to unittest.main.
if __name__ == "__main__":
    main()
